from liger_kernel.transformers import apply_liger_kernel_to_llama

def init_liger(output_loss=True):
    """Apply Liger kernels to Llama, patching only the cross-entropy path.

    RMSNorm, RoPE and SwiGLU patching are explicitly disabled. Exactly one of
    Liger's two cross-entropy options is switched on, chosen by
    ``output_loss``.

    Args:
        output_loss: When truthy, ``cross_entropy`` is enabled and
            ``fused_linear_cross_entropy`` is disabled; when falsy, the
            reverse. The value is forwarded as-is for ``cross_entropy``.
    """
    kernel_flags = {
        # The two cross-entropy variants are toggled in opposite directions
        # so they are never enabled together.
        "fused_linear_cross_entropy": not output_loss,
        "cross_entropy": output_loss,
        # Kernels intentionally left unpatched.
        "rms_norm": False,
        "rope": False,
        "swiglu": False,
    }
    apply_liger_kernel_to_llama(**kernel_flags)
